{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Patch Phenotyping Tutorial\n",
    "\n",
    "This tutorial walks you to perform patch phenotyping using KRONOS. The workflow involves extracting patch embeddings from multiplex images and classifying patch phenotypes.\n",
    "\n",
    "## Prerequisites\n",
    "\n",
    "To follow this tutorial, ensure you have the following data prepared:\n",
    "\n",
    "1. **multiplex_images**: Dir with TIFF images with multiple markers (e.g., DAPI, CD markers, etc.).\n",
    "2. **marker_info_with_metadata.csv**: A CSV file containing:\n",
    "    - `channel_id`: Identifier for the image channel.\n",
    "    - `marker_name`: Name of the marker (e.g., DAPI, CD20).\n",
    "    - `marker_id`: Unique ID for the marker.\n",
    "    - `marker_mean`: Mean intensity value for normalization.\n",
    "    - `marker_std`: Standard deviation for normalization.\n",
    "3. **cell_masks**: Dir with grayscale images where:\n",
    "    - Each cell pixel is represented by a unique cell ID.\n",
    "    - Non-cell pixels are represented by `0`.\n",
    "4. **cell_annotations.csv**: A CSV file with entries:\n",
    "    - `image_name`: Name of the multiplex image.\n",
    "    - `cell_id`: Unique ID of the cell in the cell mask.\n",
    "    - `x_center`: X-coordinate of the cell center.\n",
    "    - `y_center`: Y-coordinate of the cell center.\n",
    "    - `cell_type_id`: Phenotype label of the cell.\n",
    "\n",
    "**Notes**:<br> \n",
    "- Refer to the **[1-Data-Download-And-Preprocessing](https://github.com/mahmoodlab/KRONOS/blob/main/tutorials/1%20-%20Data-Download-And-Preprocessing.ipynb)** tutorial to prepare data.\n",
    "<br>\n",
    "- cHL dataset does not have patch phenotype labels therefore, we will generate patch phenotype lables using cell level labels with majority vote.  \n",
    "\n",
    "\n",
    "## Overview\n",
    "\n",
    "This tutorial performs **patch-based phenotyping**, which involves the following steps:\n",
    "\n",
    "1. **Patch Extraction**: Extract patches from multiplex images using sliding window approach.\n",
    "2. **Feature Extraction**: Use the pre-trained KRONOS model to extract feature embeddings for each patch.\n",
    "3. **Classification**: Train a logistic regression model to classify patch phenotypes based on the extracted features.\n",
    "4. **Evaluation**: Evaluate the model using cross-validation and compute metrics such as F1-Score, Balanced Accuracy, and ROC AUC.\n",
    "\n",
    "By the end of this tutorial, you will have a complete pipeline for patch phenotyping using KRONOS features."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 1: Experiment Configuration\n",
    "\n",
    "In this section, we define the configuration and hyperparameters required for the workflow. Ensure your dataset in the project directory is organized as follows:\n",
    "\n",
    "### Dataset Structure\n",
    "- **`dataset/multiplex_images/`**: Contains the multiplex image files (e.g., `.tiff`).\n",
    "- **`dataset/cell_masks/`**: Contains the cell mask files corresponding to the multiplex images.\n",
    "- **`dataset/marker_info_with_metadata.csv`**: A CSV file containing marker metadata.\n",
    "- **`dataset/cell_annotations.csv`**: A CSV file containing ground truth annotations.\n",
    "\n",
    "### Output Directories\n",
    "The following directories in the project directory will be generated during the workflow:\n",
    "- **`patches/`**: Stores the extracted cell-centered patches in `.h5` format.\n",
    "- **`features/`**: Stores the extracted features in `.npy` format.\n",
    "- **`folds/`**: Contains cross-validation folds for training, validation, and testing.\n",
    "- **`results/`**: Stores the results for each fold and aggregated metrics."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils import PatchPhenotyping \n",
    "\n",
    "# Define the root directory for the project\n",
    "project_dir = \"./chl_maps_dataset/\"  # Replace with your actual root directory\n",
    "\n",
    "# Configuration dictionary containing all parameters for the pipeline\n",
    "config = {\n",
    "    # Dataset-related parameters, feel free to modify these paths according to your dataset structure\n",
    "    \"image_dir_path\": f\"{project_dir}/dataset/multiplex_images/\",  # Path to the multiplex image file\n",
    "    \"cell_mask_dir_path\": f\"{project_dir}/dataset/cell_masks/\",  # Path to the cell mask file\n",
    "    \"marker_info_with_metadata_csv_path\": f\"{project_dir}/dataset/marker_info_with_metadata.csv\",  # Path to the marker metadata CSV file\n",
    "    \"gt_csv_path\": f\"{project_dir}/dataset/cell_annotations.csv\",  # Path to the ground truth annotations CSV file\n",
    "    \"num_classes\": 16,  # Number of cell types in the dataset\n",
    "    \n",
    "    # Output directories for intermediate and final results\n",
    "    \"patch_dir\": f\"{project_dir}/patches/\",  # Directory to save extracted patches\n",
    "    \"feature_dir\": f\"{project_dir}/features/\",  # Directory to save extracted features\n",
    "    \"fold_dir\": f\"{project_dir}/folds/\",  # Directory to save cross-validation folds\n",
    "    \"result_dir\": f\"{project_dir}/results/\",  # Directory to save final results\n",
    "    \n",
    "    # Model-related parameters\n",
    "    \"checkpoint_path\": \"hf_hub:MahmoodLab/kronos\",  # Path to the pre-trained model checkpoint (local or Hugging Face Hub)\n",
    "    \"hf_auth_token\": None,  # Authentication token for Hugging Face Hub (if checkpoint is from the Hugging Face Hub)\n",
    "    \"cache_dir\": f\"{project_dir}/models/\",  # Directory to cache KRONOS model if downloading model from Hugging Face Hub\n",
    "    \"model_type\": \"vits16\",  # Type of pre-trained model to use (e.g., vits16)\n",
    "    \"token_overlap\": True,  # Whether to use token overlap during feature extraction\n",
    "\n",
    "    # Patch extraction parameters\n",
    "    \"patch_size\": 32,  # Size of the patches to extract (e.g., 64x64 pixels)\n",
    "    \"patch_stride\": 16,  # Stride for patch extraction (e.g., 32 pixels)\n",
    "    \"token_size\": 16,  # Size of the tokens for the model (e.g., 16x16 pixels)\n",
    "    \"token_stride\": 8,  # Stride for token extraction (e.g., 16 or 8 pixels)\n",
    "\n",
    "    # Feature extraction parameters\n",
    "    \"marker_list\": ['DAPI-01', 'CD11B', 'CD11C', 'CD15', 'CD163', 'CD20', 'CD206', 'CD30', 'CD31', 'CD4', 'CD56', 'CD68', 'CD7', 'CD8', 'CYTOKERITIN', 'FOXP3', 'MCT', 'PODOPLANIN'],  # List of markers to be used for patch phenotyping\n",
    "    \"marker_max_values\": 65535.0,  # Maximum possible value marker image, depends on image type (e.g., 255 for uint8, 65535 for uint16)\n",
    "    \"patch_batch_size\": 16,  # Batch size for patch-based data loading\n",
    "    \"num_workers\": 4,  # Number of workers for data loading\n",
    "\n",
    "    # Logistic regression parameters\n",
    "    \"max_patches_per_type\": 1000,  # Maximum number of patches per phenotype for training, set to None for no limit\n",
    "    \"feature_batch_size\": 256,  # Batch size for feature-based data loading\n",
    "    \"n_trials\": 25,  # Number of trials for Optuna hyperparameter optimization\n",
    "    \"C_low\": 1e-10,  # Lower bound for the regularization parameter (C) in logistic regression\n",
    "    \"C_high\": 1e5,  # Upper bound for the regularization parameter (C) in logistic regression\n",
    "    \"max_iter\": 10000,  # Maximum number of iterations for logistic regression training\n",
    "}\n",
    "obj = PatchPhenotyping(config)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 2: Patch Extraction\n",
    "\n",
    "Extract patches from multiplex images using cell masks and annotations. Each patch is stored in an HDF5 (`.h5`) file with the following structure:\n",
    "                - A cell segmentation mask under the key \"cell_seg_map\".\n",
    "                - A cell type map under the key \"cell_type_map\".\n",
    "                - Individual marker datasets for each marker in the patch.\n",
    "- **Cell segmentation map**: Stored under the key `cell_seg_map`, where each non-zero pixel represent cell id and 0 represents the non-cell region.\n",
    "- **Cell type map**: Stored under the key `cell_type_map`, where each non-zero pixel represent cell type and 0 represents the non-cell region.\n",
    "- **Marker datasets**: Each marker dataset contains intensity values corresponding to a specific marker within the patch.\n",
    "\n",
    "**File Naming Convention**: `{image_name}_{x}_{y}.h5`  \n",
    "Example: `\"image_01_000000_000032.h5\"` represents patch who top-left corner is at [0,32] in the image `\"image_01.tiff\"`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "obj.patch_extraction()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 3: Feature Extraction\n",
    "\n",
    "Extract features from patches using a pre-trained model. This step processes patches of multiplex images, applies a pre-trained model to extract meaningful feature vectors, and saves them for downstream analysis."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "obj.feature_extraction()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 4: Generate Data Folds\n",
    "\n",
    "Generate data folds for training, validation, and testing. This tutorial uses a cHL dataset, dividing a large visual field into four quadrants. Each quadrant represents one fold."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "obj.folds_generation()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 5: Train Cell Phenotyping Model\n",
    "\n",
    "Train a logistic regression model on the training data, validate it on the validation data, and evaluate it on the test data. Results for each fold are saved in the output directory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "obj.train_classification_model()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 6: Evaluate Model\n",
    "\n",
    "Read the test results for each fold, compute evaluation metrics (F1-Score, Balanced Accuracy, Average Precision, and ROC AUC), and calculate the average and standard deviation across all folds."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "obj.eval_classification_model()\n",
    "obj.calculate_results()\n",
    "obj.plot_results()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 7: Visual Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "obj.visual_results()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "kronos_v2",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
